
## IMPORT

from pyspark.context import SparkContext, SparkConf
from pyspark.sql import SparkSession, Window, types
from pyspark.sql.types import StructType, IntegerType
from pyspark.sql.context import SQLContext
import pyspark.sql.functions as f

from math import floor, ceil
import sys
import pandas as pd

from splink.spark.spark_linker import SparkLinker
from splink.spark.spark_comparison_library import jaro_winkler_at_thresholds, levenshtein_at_thresholds, exact_match
from splink.charts import save_offline_chart

## SPARK SESSION

# Reuse the running SparkContext when there is one (otherwise start one from a
# default SparkConf) and wrap it in a SparkSession.
sc = SparkContext.getOrCreate(conf=SparkConf())
spark = SparkSession(sc)

# `sc` is the context backing `spark`, so the log level is set on it directly.
sc.setLogLevel('WARN')
sc.setCheckpointDir("splink_sandbox/temp_graphframes/")

# Make the Java Jaro-Winkler implementation callable from Spark SQL under the
# name `jaro_winkler`, returning a double.
spark.udf.registerJavaFunction(
    'jaro_winkler',
    'uk.gov.moj.dash.linkage.JaroWinklerSimilarity',
    types.DoubleType(),
)

## LOAD DATA FROM SLURM PARAMS

# The first command-line argument is a 1-based index into the election cycles
# below (1 -> 2008, ..., 7 -> 2020).
# NOTE: an index of 0 is not rejected; through Python's negative indexing it
# selects the last cycle (2020).
cycles = [2008, 2010, 2012, 2014, 2016, 2018, 2020]
year = str(cycles[int(sys.argv[1]) - 1])

cf = spark.read.parquet(f"data/fec_clean/cycle_table={year}/")

# Keep only rows that have a po_num and are flagged for deduplication.
cf = cf.filter("po_num IS NOT NULL AND incl_in_dedupe = 1")

# Count once and reuse the result: every count() is a separate Spark job, and
# repartitioning does not change the number of rows.
n_rows = cf.count()

# Aim for roughly 1000 rows per partition. repartition() requires a positive
# partition count, so fall back to a single partition when no rows survive
# the filter.
cf = cf.repartition(max(1, ceil(n_rows / 1000)))
print(f"There are {n_rows} cf rows.")

## WRITE MODEL SETTINGS

# Clause appended to every blocking rule: only pair records whose gender
# values differ by less than 1.001.
no_comp = " AND abs(l.gender - r.gender) < 1.001"


def _equal(column):
    """Return the SQL clause requiring `column` to be equal on both sides."""
    return f"l.{column} = r.{column}"


def _same_prefix(column, length):
    """Return the SQL clause requiring the first `length` characters of `column` to agree."""
    return f"substring(l.{column},1,{length}) = substring(r.{column},1,{length})"


def _rule(*clauses):
    """Join the clauses with AND and append the shared gender clause."""
    return " AND ".join(clauses) + no_comp


# Rules 1-2 block on zip code plus one exact name part.
br1 = _rule(_equal("contbr_zip"), _equal("contbr_nm_first"))
br2 = _rule(_equal("contbr_zip"), _equal("contbr_nm_last"))

# Rules 3-4 block on state and city, one exact name part and the first two
# characters of the other name part.
br3 = _rule(
    _equal("contbr_st"),
    _equal("contbr_nm_last"),
    _same_prefix("contbr_nm_first", 2),
    _equal("contbr_city"),
)
br4 = _rule(
    _equal("contbr_st"),
    _equal("contbr_nm_first"),
    _same_prefix("contbr_nm_last", 2),
    _equal("contbr_city"),
)

# Rules 5-6 block on state, city and po_num plus the initial of one name part.
br5 = _rule(
    _equal("contbr_st"),
    _equal("po_num"),
    _same_prefix("contbr_nm_first", 1),
    _equal("contbr_city"),
)
br6 = _rule(
    _equal("contbr_st"),
    _equal("po_num"),
    _same_prefix("contbr_nm_last", 1),
    _equal("contbr_city"),
)

# (blocking rule, number of salting partitions), in prediction order.
salted_rules = [(br1, 10), (br2, 10), (br3, 10), (br4, 10), (br5, 1), (br6, 1)]

# Options shared by the two name comparisons.
name_options = {
    "include_exact_match_level": True,
    "term_frequency_adjustments": True,
}

settings = {
    "link_type": "dedupe_only",
    "blocking_rules_to_generate_predictions": [
        {"blocking_rule": rule, "salting_partitions": partitions}
        for rule, partitions in salted_rules
    ],
    "comparisons": [
        jaro_winkler_at_thresholds("contbr_name_fm", [0.94, 0.90, 0.86], **name_options),
        jaro_winkler_at_thresholds("contbr_name_ls", [0.94, 0.90, 0.86], **name_options),
        levenshtein_at_thresholds(
            "po_num",
            [1, 2],
            include_exact_match_level=True,
            term_frequency_adjustments=False,
        ),
        jaro_winkler_at_thresholds(
            "contbr_city",
            [0.94, 0.88],
            include_exact_match_level=False,
            term_frequency_adjustments=True,
        ),
    ],
    "retain_matching_columns": True,
    "retain_intermediate_calculation_columns": False,
    "unique_id_column_name": "emmid",
    "max_iterations": 40,
}

## TRAIN MODEL

# Count once and reuse the value for both the log line and the u-estimation
# sample size; every cf.count() is a separate Spark job over the input.
n_train_rows = cf.count()
print(f"Training this year address model with {n_train_rows} rows.")
linker = SparkLinker(input_table_or_tables=cf, settings_dict=settings, spark=spark)

linker.estimate_u_using_random_sampling(target_rows=n_train_rows)


def _run_em_session(blocking_rule, deactivated):
    """Run one EM training session with u probabilities held fixed.

    `deactivated` lists the output column names of the comparisons to switch
    off for this session. For each of them, the comparison level with
    comparison vector value 1 is passed as a level to reverse the blocking
    rule, in the same order. The levels are looked up on the linker at call
    time, immediately before the session runs.
    """
    reverse_levels = [
        linker._settings_obj
        ._get_comparison_by_output_column_name(name)
        ._get_comparison_level_by_comparison_vector_value(1)
        for name in deactivated
    ]
    linker.estimate_parameters_using_expectation_maximisation(
        blocking_rule,
        comparisons_to_deactivate=list(deactivated),
        comparison_levels_to_reverse_blocking_rule=reverse_levels,
        fix_u_probabilities=True,
    )


# (EM blocking rule, comparisons deactivated during that session), run in order.
em_sessions = (
    (
        "l.contbr_zip = r.contbr_zip AND l.contbr_nm_first = r.contbr_nm_first",
        ["contbr_city", "contbr_name_fm"],
    ),
    (
        "l.contbr_zip = r.contbr_zip AND l.contbr_nm_last = r.contbr_nm_last",
        ["contbr_city", "contbr_name_ls"],
    ),
    (
        "l.contbr_st = r.contbr_st AND l.contbr_nm_last = r.contbr_nm_last AND l.contbr_nm_first = r.contbr_nm_first",
        ["contbr_name_ls", "contbr_name_fm"],
    ),
)

for em_blocking_rule, em_deactivated in em_sessions:
    _run_em_session(em_blocking_rule, em_deactivated)

# Persist the trained model and its match-weights chart for this cycle.
linker.save_settings_to_json(f"logs/fec_dedupe_{year}_p.json", overwrite=True)
save_offline_chart(
    linker.match_weights_chart().spec,
    f"logs/fec_dedupe_{year}_p.html",
    overwrite=True,
)

## PREDICT MATCHES AND SAVE

# To score with a previously saved model instead of the linker trained in
# this run, rebuild the linker and load the saved settings:
# linker = SparkLinker(cf, spark = spark)
# linker.load_settings_from_json(f"logs/fec_dedupe_{year}_p.json")

# Columns kept from the prediction output.
output_columns = [
    "match_probability",
    "emmid_l",
    "emmid_r",
    "gamma_contbr_name_ls",
    "gamma_contbr_name_fm",
    "gamma_po_num",
]

# Keep a pair when both name comparisons score above level 0, or when one
# name comparison is -1 and the other is 4.
# NOTE(review): presumably -1 is the null level and 4 the exact-match level of
# the name comparisons - confirm against the Splink comparison definitions.
keep_condition = (
    "(gamma_contbr_name_ls > 0 AND gamma_contbr_name_fm > 0)"
    " OR (gamma_contbr_name_ls = -1 AND gamma_contbr_name_fm = 4)"
    " OR (gamma_contbr_name_ls = 4 AND gamma_contbr_name_fm = -1)"
)

predictions = linker.predict(threshold_match_probability=0.5)
df = predictions.as_spark_dataframe().select(*output_columns).filter(keep_condition)

print(f"There are {df.count()} predicted matches with posterior >= 0.5")

# Write the matches for this cycle as at most 10 parquet files, replacing any
# earlier output.
df.coalesce(10).write.mode("overwrite").parquet(f"data/fec_dedupe_matches/{year}_p")
